"""The program includes several functions: setting a random seed, 
loading data from a JSON file, batching data, and extracting answers from generated text.
"""

import random
import numpy as np
import torch
import json
import re
def set_random_seed(seed: int):
    """
    Set the random seed for `random`, `numpy`, `torch`, `torch.cuda`.

    Parameters
    ------------
    seed : int
        The default seed.
        
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def load_data(file_name: str):
    """
    Load data with file name.

    Parameters
    ------------
    file_name : str.
        The dataset file name.
    
    Returns
    ------------
    inputs : list.
        The input texts of the dataset.
    outputs : list.
        The output texts file datasets.    
    len : int.
        The length of the dataset.
    """
    inputs = []
    outputs = []
    type = ""
    with open(file_name, encoding='utf-8') as f:
        json_data = json.load(f)
        type = json_data["type"]
        for line in json_data["instances"]:
            inputs.append(line["input"])
            outputs.append(line["output"])

    print(f"load dataset {file_name} success.\n")
    print(f"Type : {type}, datasize : {len(outputs)}")

    return inputs, outputs, len(outputs)

def batchlize(examples: list, batch_size: int, random_shuffle: bool):
    """
    Convert examples to a dataloader.

    Parameters
    ------------
    examples : list.
        Data list.
    batch_size : int.

    random_shuffle : bool
        If true, the dataloader shuffle the training data.
    
    Returns
    ------------
    dataloader:
        Dataloader with batch generator.
    """
    size = 0
    dataloader = []
    length = len(examples)
    if (random_shuffle):
        random.shuffle(examples)
    while size < length:
        if length - size > batch_size:
            dataloader.append(examples[size : size+batch_size])
            size += batch_size
        else:
            dataloader.append(examples[size : size+(length-size)])
            size += (length - size)
    return dataloader



def answer_extraction(response, answer_type=None):   #use this funtion to extract answers from generated text

    """
    Use this funtion to extract answers from generated text

    Parameters
    ------------
    args : 
        Arguments.
    response : str
        plain string response.


    Returns
    ------------
    answer:
        Decoded answer (such as A, B, C, D, E for mutiple-choice QA).
    """

    # temp = response["generated_text"]
    temp = response
    if answer_type in ("gsm8k", "svamp", "asdiv", "addsub", "singleeq", "multiarith", "math"):
        temp = temp.replace(",", "")
        temp = [s for s in re.findall(r'-?\d+\.?\d*', temp)]
    elif answer_type in ("aqua", "csqa", "multiple_choice"):
        temp = re.findall(r'A|B|C|D|E', temp)
    elif answer_type in ("strategyqa", "coin_flip"):
        temp = temp.lower()
        temp = re.sub("\"|\'|\n|\.|\s|\:|\,"," ", temp)
        temp = temp.split(" ")
        temp = [i for i in temp if i in ("yes", "no")]
    elif answer_type in ("last_letters"):
        temp = re.sub("\"|\'|\n|\.|\s","", temp)
        temp = [temp]
    elif answer_type in ("pubmedqa", "binary_choice"):
        # pattern = "Output: (yes|no|maybe)"
        # sttr = re.search(pattern, temp)
        # answer = sttr.group(0)[8:] if sttr is not None else "N/A"
        pattern = "(answer|Answer|ANSWER|output|Output|OUTPUT|A): \(*(yes|Yes|YES|no|No|NO|maybe|Maybe|MAYBE)"
        sttr = re.search(pattern, temp)
        if sttr is not None:
            mid_answer = sttr.group(0)
            mid_answer = mid_answer.split(":")[-1].strip()
            answer = mid_answer.lower()
        else:
            pattern = "(yes|Yes|YES|no|No|NO|maybe|Maybe|MAYBE)(\.|\s)"
            sttr = re.search(pattern, temp)
            if sttr is not None:
                answer = sttr.group(0)[:-1].lower()
            else:
                answer = "N/A"
        return answer
    elif answer_type == "medmcqa":
        # pattern = "Output: (A|B|C|D)."
        # sttr = re.search(pattern, temp)
        # answer = sttr.group(0)[8:-1].lower() if sttr is not None else "N/A"
        pattern = "(answer|Answer|ANSWER|output|Output|OUTPUT|A): \(*(A|B|C|D|a|b|c|d)"
        sttr = re.search(pattern, temp)
        if sttr is not None:
            mid_answer = sttr.group(0)
            answer = mid_answer[-1].lower()
        else:
            pattern = "\(*(A|B|C|D|a|b|c|d)\)*(\.|\s)"
            sttr = re.search(pattern, temp)
            if sttr is not None:
                if '(' in sttr.group(0):
                    answer = sttr.group(0)[1].lower()
                else:
                    answer = sttr.group(0)[0].lower()
            else:
                answer = "N/A"
        return answer

    elif answer_type == "usmle":
        # pattern = "Output: (A|B|C|D)."
        # sttr = re.search(pattern, temp)
        # answer = sttr.group(0)[8:-1].lower() if sttr is not None else "N/A"
        pattern = "(Answer|Output|A): \(*(A|B|C|D|a|b|c|d)"
        sttr = re.search(pattern, temp)
        if sttr is not None:
            mid_answer = sttr.group(0)
            answer = mid_answer[-1].lower()
        else:
            pattern = "\(*(A|B|C|D|a|b|c|d)\)*(\.|\s)"
            sttr = re.search(pattern, temp)
            if sttr is not None:
                if '(' in sttr.group(0):
                    answer = sttr.group(0)[1].lower()
                else:
                    answer = sttr.group(0)[0].lower()
            else:
                answer = "N/A"
        return answer
    elif answer_type == "text":
        return response
    else:
        raise NotImplementedError(f"Unsupported answer type: {answer_type}")

    if len(temp) != 0:
        answer = temp[-1]
        # if there is . at the end of answer, remove it
        # e.g. answer = 64.
        if answer != "":
            if answer[-1] == ".":
                answer = answer[:-1]

            # round the answer to nearest integer
        if answer_type in ("gsm8k", "svamp"):
            try:
                answer = str(round(float(answer)))
            except:
                answer = "" # no sol or sol doesn't have valid format
        elif answer_type in ("last_letters"):
            try:
                answer = answer[-args.concat_length:]
            except:
                answer = ""
    else:
        answer = ""
    return answer


def process_image_flag(text, image_flag="<ImageHere>"):
    texts = text.split(image_flag)
    if len(texts) > 1:
        image_token_indexes = [len(text) for text in texts[:-1]]
    else:
        image_token_indexes = []
    # cumsun
    image_token_indexes = list(np.cumsum(image_token_indexes))
    texts = "".join(texts)
    return texts, image_token_indexes